{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/tutorials/W3D3_NetworkCausality/student/W3D3_Tutorial4.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Neuromatch Academy 2020 -- Week 3 Day 5, Tutorial 4\n",
    "\n",
    "# Causality Day: Instrumental Variables\n",
    "\n",
    "**Content creators**: Ari Benjamin, Tony Liu, Konrad Kording\n",
    "\n",
    "**Content reviewers**: Mike X Cohen, Madineh Sarvestani, Ella Batty, Michael Waskom"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Tutorial objectives\n",
    "\n",
    "This is our final tutorial on our day of examining causality. Below is the high level outline of what we've covered today, with the sections we will focus on in this notebook in bold:\n",
    "\n",
    "1.   Master definitions of causality\n",
    "2.   Understand that estimating causality is possible\n",
    "3.   Learn 4 different methods and understand when they fail\n",
    "    1. perturbations\n",
    "    2. correlations\n",
    "    3. simultaneous fitting/regression\n",
    "    4. **instrumental variables**\n",
    "\n",
    "### Notebook 4 Objectives\n",
    "\n",
    "In tutorial 3 we saw that even more sophisticated techniques such as simultaneous fitting fail to capture causality in the presence of omitted variable bias. So what techniques are there for us to obtain valid causal measurements when we can't perturb the system? Here we will:\n",
    "\n",
    "- learn about **instrumental variables,** a method that does not require experimental data for valid causal analysis\n",
    "- explore benefits of instrumental variable analysis and limitations\n",
    "    - addresses **omitted variable bias** seen in regression\n",
    "    - less efficient in terms of sample size than other techniques\n",
    "    - requires a particular form of randomness in the system in order for causal effects to be identified"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:07.843681Z",
     "iopub.status.busy": "2021-05-25T01:25:07.843105Z",
     "iopub.status.idle": "2021-05-25T01:25:08.568278Z",
     "shell.execute_reply": "2021-05-25T01:25:08.566092Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "from sklearn.multioutput import MultiOutputRegressor\n",
    "from sklearn.linear_model import LinearRegression, Lasso"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.573940Z",
     "iopub.status.busy": "2021-05-25T01:25:08.573366Z",
     "iopub.status.idle": "2021-05-25T01:25:08.657911Z",
     "shell.execute_reply": "2021-05-25T01:25:08.657115Z"
    }
   },
   "outputs": [],
   "source": [
    "#@title Figure settings\n",
    "import ipywidgets as widgets       # interactive display\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.690064Z",
     "iopub.status.busy": "2021-05-25T01:25:08.676751Z",
     "iopub.status.idle": "2021-05-25T01:25:08.698776Z",
     "shell.execute_reply": "2021-05-25T01:25:08.698236Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title Helper functions\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    \"\"\"\n",
    "    Compute sigmoid nonlinearity element-wise on x.\n",
    "\n",
    "    Args:\n",
    "        x (np.ndarray): the numpy data array we want to transform\n",
    "    Returns\n",
    "        (np.ndarray): x with sigmoid nonlinearity applied\n",
    "    \"\"\"\n",
    "    return 1 / (1 + np.exp(-x))\n",
    "\n",
    "\n",
    "def logit(x):\n",
    "    \"\"\"\n",
    "\n",
    "    Applies the logit (inverse sigmoid) transformation\n",
    "\n",
    "    Args:\n",
    "        x (np.ndarray): the numpy data array we want to transform\n",
    "    Returns\n",
    "        (np.ndarray): x with logit nonlinearity applied\n",
    "    \"\"\"\n",
    "    return np.log(x/(1-x))\n",
    "\n",
    "\n",
    "def create_connectivity(n_neurons, random_state=42, p=0.9):\n",
    "    \"\"\"\n",
    "    Generate our nxn causal connectivity matrix.\n",
    "\n",
    "    Args:\n",
    "        n_neurons (int): the number of neurons in our system.\n",
    "        random_state (int): random seed for reproducibility\n",
    "\n",
    "    Returns:\n",
    "        A (np.ndarray): our 0.1 sparse connectivity matrix\n",
    "    \"\"\"\n",
    "    np.random.seed(random_state)\n",
    "    A_0 = np.random.choice([0, 1], size=(n_neurons, n_neurons), p=[p, 1 - p])\n",
    "\n",
    "    # set the timescale of the dynamical system to about 100 steps\n",
    "    _, s_vals, _ = np.linalg.svd(A_0)\n",
    "    A = A_0 / (1.01 * s_vals[0])\n",
    "\n",
    "    # _, s_val_test, _ = np.linalg.svd(A)\n",
    "    # assert s_val_test[0] < 1, \"largest singular value >= 1\"\n",
    "\n",
    "    return A\n",
    "\n",
    "\n",
    "def see_neurons(A, ax):\n",
    "    \"\"\"\n",
    "    Visualizes the connectivity matrix.\n",
    "\n",
    "    Args:\n",
    "        A (np.ndarray): the connectivity matrix of shape (n_neurons, n_neurons)\n",
    "        ax (plt.axis): the matplotlib axis to display on\n",
    "\n",
    "    Returns:\n",
    "        Nothing, but visualizes A.\n",
    "    \"\"\"\n",
    "    A = A.T  # make up for opposite connectivity\n",
    "    n = len(A)\n",
    "    ax.set_aspect('equal')\n",
    "    thetas = np.linspace(0, np.pi * 2, n,endpoint=False)\n",
    "    x, y = np.cos(thetas), np.sin(thetas),\n",
    "    ax.scatter(x, y, c='k',s=150)\n",
    "    A = A / A.max()\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "            if A[i, j] > 0:\n",
    "                ax.arrow(x[i], y[i], x[j] - x[i], y[j] - y[i], color='k', alpha=A[i, j], head_width=.15,\n",
    "                        width = A[i,j] / 25, shape='right', length_includes_head=True)\n",
    "    ax.axis('off')\n",
    "\n",
    "\n",
    "def simulate_neurons(A, timesteps, random_state=42):\n",
    "    \"\"\"\n",
    "    Simulates a dynamical system for the specified number of neurons and timesteps.\n",
    "\n",
    "    Args:\n",
    "        A (np.array): the connectivity matrix\n",
    "        timesteps (int): the number of timesteps to simulate our system.\n",
    "        random_state (int): random seed for reproducibility\n",
    "\n",
    "    Returns:\n",
    "        - X has shape (n_neurons, timeteps).\n",
    "    \"\"\"\n",
    "    np.random.seed(random_state)\n",
    "\n",
    "    n_neurons = len(A)\n",
    "    X = np.zeros((n_neurons, timesteps))\n",
    "\n",
    "    for t in range(timesteps - 1):\n",
    "        # solution\n",
    "        epsilon = np.random.multivariate_normal(np.zeros(n_neurons), np.eye(n_neurons))\n",
    "        X[:, t + 1] = sigmoid(A.dot(X[:, t]) + epsilon)\n",
    "\n",
    "        assert epsilon.shape == (n_neurons,)\n",
    "    return X\n",
    "\n",
    "\n",
    "def correlation_for_all_neurons(X):\n",
    "  \"\"\"Computes the connectivity matrix for the all neurons using correlations\n",
    "\n",
    "    Args:\n",
    "        X: the matrix of activities\n",
    "\n",
    "    Returns:\n",
    "        estimated_connectivity (np.ndarray): estimated connectivity for the selected neuron, of shape (n_neurons,)\n",
    "  \"\"\"\n",
    "  n_neurons = len(X)\n",
    "  S = np.concatenate([X[:, 1:], X[:, :-1]], axis=0)\n",
    "  R = np.corrcoef(S)[:n_neurons, n_neurons:]\n",
    "  return R\n",
    "\n",
    "\n",
    "def get_sys_corr(n_neurons, timesteps, random_state=42, neuron_idx=None):\n",
    "    \"\"\"\n",
    "    A wrapper function for our correlation calculations between A and R.\n",
    "\n",
    "    Args:\n",
    "        n_neurons (int): the number of neurons in our system.\n",
    "        timesteps (int): the number of timesteps to simulate our system.\n",
    "        random_state (int): seed for reproducibility\n",
    "        neuron_idx (int): optionally provide a neuron idx to slice out\n",
    "\n",
    "    Returns:\n",
    "        A single float correlation value representing the similarity between A and R\n",
    "    \"\"\"\n",
    "\n",
    "    A = create_connectivity(n_neurons, random_state)\n",
    "    X = simulate_neurons(A, timesteps)\n",
    "\n",
    "    R = correlation_for_all_neurons(X)\n",
    "\n",
    "    return np.corrcoef(A.flatten(), R.flatten())[0, 1]\n",
    "\n",
    "\n",
    "def print_corr(v1, v2, corrs, idx_dict):\n",
    "    \"\"\"Helper function for formatting print statements for correlations\"\"\"\n",
    "    text_dict = {'Z':'taxes', 'T':'# cigarettes', 'C':'SES status', 'Y':'birth weight'}\n",
    "    print(\"Correlation between {} and {} ({} and {}): {:.3f}\".format(v1, v2, text_dict[v1], text_dict[v2], corrs[idx_dict[v1], idx_dict[v2]]))\n",
    "\n",
    "\n",
    "def get_regression_estimate(X, neuron_idx=None):\n",
    "    \"\"\"\n",
    "    Estimates the connectivity matrix using lasso regression.\n",
    "\n",
    "    Args:\n",
    "        X (np.ndarray): our simulated system of shape (n_neurons, timesteps)\n",
    "        neuron_idx (int): optionally provide a neuron idx to compute connectivity for\n",
    "    Returns:\n",
    "        V (np.ndarray): estimated connectivity matrix of shape (n_neurons, n_neurons).\n",
    "                        if neuron_idx is specified, V is of shape (n_neurons,).\n",
    "    \"\"\"\n",
    "    n_neurons = X.shape[0]\n",
    "\n",
    "    # Extract Y and W as defined above\n",
    "    W = X[:, :-1].transpose()\n",
    "    if neuron_idx is None:\n",
    "        Y = X[:, 1:].transpose()\n",
    "    else:\n",
    "        Y = X[[neuron_idx], 1:].transpose()\n",
    "\n",
    "    # apply inverse sigmoid transformation\n",
    "    Y = logit(Y)\n",
    "\n",
    "    # fit multioutput regression\n",
    "    regression = MultiOutputRegressor(Lasso(fit_intercept=False, alpha=0.01), n_jobs=-1)\n",
    "\n",
    "    regression.fit(W,Y)\n",
    "\n",
    "    if neuron_idx is None:\n",
    "        V = np.zeros((n_neurons, n_neurons))\n",
    "        for i, estimator in enumerate(regression.estimators_):\n",
    "            V[i, :] = estimator.coef_\n",
    "    else:\n",
    "        V = regression.estimators_[0].coef_\n",
    "\n",
    "    return V\n",
    "\n",
    "\n",
    "def get_regression_corr(n_neurons, timesteps, random_state, observed_ratio, regression_args, neuron_idx=None):\n",
    "    \"\"\"\n",
    "    A wrapper function for our correlation calculations between A and the V estimated\n",
    "    from regression.\n",
    "\n",
    "    Args:\n",
    "        n_neurons (int): the number of neurons in our system.\n",
    "        timesteps (int): the number of timesteps to simulate our system.\n",
    "        random_state (int): seed for reproducibility\n",
    "        observed_ratio (float): the proportion of n_neurons observed, must be betweem 0 and 1.\n",
    "        regression_args (dict): dictionary of lasso regression arguments and hyperparameters\n",
    "        neuron_idx (int): optionally provide a neuron idx to compute connectivity for\n",
    "\n",
    "    Returns:\n",
    "        A single float correlation value representing the similarity between A and R\n",
    "    \"\"\"\n",
    "    assert (observed_ratio > 0) and (observed_ratio <= 1)\n",
    "\n",
    "    A = create_connectivity(n_neurons, random_state)\n",
    "    X = simulate_neurons(A, timesteps)\n",
    "\n",
    "    sel_idx = np.clip(int(n_neurons*observed_ratio), 1, n_neurons)\n",
    "\n",
    "    sel_X = X[:sel_idx, :]\n",
    "    sel_A = A[:sel_idx, :sel_idx]\n",
    "\n",
    "    sel_V = get_regression_estimate(sel_X, neuron_idx=neuron_idx)\n",
    "    if neuron_idx is None:\n",
    "        return np.corrcoef(sel_A.flatten(), sel_V.flatten())[1, 0]\n",
    "    else:\n",
    "        return np.corrcoef(sel_A[neuron_idx, :], sel_V)[1, 0]\n",
    "\n",
    "\n",
    "def compare_iv_estimate_to_regression(observed_ratio):\n",
    "  \"\"\"\n",
    "  A wrapper function to compare IV and Regressor performance as a function of observed neurons\n",
    "\n",
    "  Args:\n",
    "        observed_ratio(list): a list of different observed ratios (out of the whole system)\n",
    "  \"\"\"\n",
    "\n",
    "  #Let's compare IV estimates to our regression estimates, uncomment the code below\n",
    "\n",
    "  reg_corrs = np.zeros((len(observed_ratio),))\n",
    "  iv_corrs = np.zeros((len(observed_ratio),))\n",
    "  for j, ratio in enumerate(observed_ratio):\n",
    "      print(ratio)\n",
    "      sel_idx = int(ratio * n_neurons)\n",
    "\n",
    "      sel_X = X[:sel_idx, :]\n",
    "      sel_Z = X[:sel_idx, :]\n",
    "      sel_A = A[:sel_idx, :sel_idx]\n",
    "\n",
    "      sel_reg_V = get_regression_estimate(sel_X)\n",
    "      reg_corrs[j] = np.corrcoef(sel_A.flatten(), sel_reg_V.flatten())[1, 0]\n",
    "\n",
    "      sel_iv_V = get_iv_estimate_network(sel_X, sel_Z)\n",
    "      iv_corrs[j] = np.corrcoef(sel_A.flatten(), sel_iv_V.flatten())[1, 0]\n",
    "\n",
    "  # Plotting IV vs lasso performance\n",
    "  plt.plot(observed_ratio, reg_corrs)\n",
    "  plt.plot(observed_ratio, iv_corrs)\n",
    "  plt.xlim([1, 0.2])\n",
    "  plt.ylabel(\"Connectivity matrices correlation with truth\")\n",
    "  plt.xlabel(\"Fraction of observed variables\")\n",
    "  plt.title(\"IV and lasso performance as a function of observed neuron ratio\")\n",
    "  plt.legend(['Regression', 'IV'])\n",
    "\n",
    "def plot_neural_activity(X):\n",
    "  \"\"\"Plot first 10 timesteps of neural activity\n",
    "\n",
    "  Args:\n",
    "    X (ndarray): neural activity (n_neurons by timesteps)\n",
    "\n",
    "  \"\"\"\n",
    "  f, ax = plt.subplots()\n",
    "  im = ax.imshow(X[:, :10], aspect='auto')\n",
    "  divider = make_axes_locatable(ax)\n",
    "  cax1 = divider.append_axes(\"right\", size=\"5%\", pad=0.15)\n",
    "  plt.colorbar(im, cax=cax1)\n",
    "  ax.set(xlabel='Timestep', ylabel='Neuron', title='Simulated Neural Activity')\n",
    "\n",
    "\n",
    "def compare_granger_connectivity(A, reject_null, selected_neuron):\n",
    "  \"\"\"Plot granger connectivity vs true\n",
    "\n",
    "  Args:\n",
    "    A (ndarray): true connectivity (n_neurons by n_neurons)\n",
    "    reject_null (list): outcome of granger causality, length n_neurons\n",
    "    selecte_neuron (int): the neuron we are plotting connectivity from\n",
    "\n",
    "  \"\"\"\n",
    "  fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
    "\n",
    "  im = axs[0].imshow(A[:, [selected_neuron]], cmap='coolwarm', aspect='auto')\n",
    "  plt.colorbar(im, ax = axs[0])\n",
    "  axs[0].set_xticks([0])\n",
    "  axs[0].set_xticklabels([selected_neuron])\n",
    "  axs[0].title.set_text(\"True connectivity for neuron {}\".format(selected_neuron))\n",
    "\n",
    "  im = axs[1].imshow(np.array([reject_null]).transpose(), cmap='coolwarm', aspect='auto')\n",
    "  plt.colorbar(im, ax=axs[1])\n",
    "  axs[1].set_xticks([0])\n",
    "  axs[1].set_xticklabels([selected_neuron])\n",
    "  axs[1].title.set_text(\"Granger causality connectivity for neuron {}\".format(selected_neuron))\n",
    "\n",
    "\n",
    "def plot_performance_vs_eta(etas, corr_data):\n",
    "  \"\"\" Plot IV estimation performance as a function of instrument strength\n",
    "\n",
    "    Args:\n",
    "      etas (list): list of instrument strengths\n",
    "      corr_data (ndarray): n_trials x len(etas) array where each element is the correlation\n",
    "        between true and estimated connectivity matries for that trial and\n",
    "        instrument strength\n",
    "\n",
    "  \"\"\"\n",
    "  corr_mean = corr_data.mean(axis=0)\n",
    "  corr_std = corr_data.std(axis=0)\n",
    "\n",
    "  plt.plot(etas, corr_mean)\n",
    "  plt.fill_between(etas,\n",
    "              corr_mean - corr_std,\n",
    "              corr_mean + corr_std,\n",
    "              alpha=.2)\n",
    "  plt.xlim([etas[0], etas[-1]])\n",
    "  plt.title(\"IV performance as a function of instrument strength\")\n",
    "  plt.ylabel(\"Correlation b.t. IV and true connectivity\")\n",
    "  plt.xlabel(\"Strength of instrument (eta)\")\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Section 1: Instrumental Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 517
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.707253Z",
     "iopub.status.busy": "2021-05-25T01:25:08.706688Z",
     "iopub.status.idle": "2021-05-25T01:25:08.755626Z",
     "shell.execute_reply": "2021-05-25T01:25:08.756101Z"
    },
    "outputId": "61f30e56-ae7d-452b-8877-3bca7bef5f88"
   },
   "outputs": [],
   "source": [
    "#@title Video 1: Instrumental Variables\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"0gkav6BS4-w\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "If there is randomness naturally occurring in the system *that we can observe*, this in effect becomes the perturbations we can use to recover causal effects. This is called an **instrumental variable**. At high level, an instrumental variable must\n",
    "\n",
    "\n",
    "1.   be observable\n",
    "2.   effect a covariate you care about\n",
    "3.   **not** effect the outcome, except through the covariate\n",
    "\n",
    "\n",
    "It's rare to find these things in the wild, but when you do it's very powerful. \n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "![image.png]()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Section 1.1: A non-neuro example of an IV\n",
    "A classic example is estimating the effect of smoking cigarettes while pregnant on the birth weight of the infant. There is a (negative) correlation, but is it causal? Unfortunately many confounds affect both birth weight and smoking. Wealth is a big one.\n",
    "\n",
    "Instead of controlling for everything imaginable, one can find an IV. Here the instrumental variable is **state taxes on tobacco**. These\n",
    "\n",
    "\n",
    "1.   Are observable\n",
    "2.   Affect tobacco consumption\n",
    "3.   Don't affect birth weight except through tobacco\n",
    "\n",
    "By using the power of IV techniques, you can determine the causal effect without exhaustively controlling for everything.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Let's represent our tobacco example above with the following notation:\n",
    "\n",
    "- $Z_{\\text{taxes}}$: our tobacco tax **instrument**, which only affects an individual's tendency to smoke while pregnant within our system\n",
    "- $T_{\\text{smoking}}$: number of cigarettes smoked per day while pregnant, our \"treatment\" if this were a randomized trial\n",
    "- $C_{\\text{SES}}$: socioeconomic status (higher means weathier), a **confounder** if it is not observed\n",
    "- $Y_{\\text{birthweight}}$: child birthweight in grams, our outcome of interest\n",
    "\n",
    "Let's suppose we have the following function for our system:\n",
    "\n",
    "$Y_{\\text{birthweight}} = 3000 + C_{\\text{SES}} - 2T_{\\text{smoking}},$\n",
    "\n",
    "with the additional fact that $C_{\\text{SES}}$ is negatively correlated with $T_{\\text{smoking}}$. \n",
    "\n",
    "The causal effect we wish to estimate is the coefficient $-2$ for $T_{\\text{smoking}}$, which means that if a mother smokes one additional cigarette per day while pregnant her baby will be 2 grams lighter at birth.\n",
    "\n",
    "We've provided a covariance matrix with the desired structure in the code cell below, so please run it to look at the correlations between our variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.764513Z",
     "iopub.status.busy": "2021-05-25T01:25:08.763682Z",
     "iopub.status.idle": "2021-05-25T01:25:08.768526Z",
     "shell.execute_reply": "2021-05-25T01:25:08.768045Z"
    },
    "outputId": "278586d8-741d-4eef-9cd7-5faa6b543955"
   },
   "outputs": [],
   "source": [
    "#@markdown Execute this cell to see correlations with C\n",
    "# run this code below to generate our setup\n",
    "idx_dict = {\n",
    "    'Z': 0,\n",
    "    'T': 1,\n",
    "    'C': 2,\n",
    "    'Y': 3\n",
    "}\n",
    "# vars:             Z    T    C\n",
    "covar = np.array([[1.0, 0.5, 0.0],  # Z\n",
    "                  [0.5, 1.0, -0.5],  # T\n",
    "                  [0.0, -0.5, 1.0]])  # C\n",
    "# vars:  Z  T  C\n",
    "means = [0, 5, 2]\n",
    "\n",
    "# generate some data\n",
    "np.random.seed(42)\n",
    "data = np.random.multivariate_normal(mean=means, cov=2 * covar, size=2000)\n",
    "\n",
    "# generate Y from our equation above\n",
    "Y = 3000 + data[:, idx_dict['C']] - (2 * (data[:, idx_dict['T']]))\n",
    "\n",
    "data = np.concatenate([data, Y.reshape(-1, 1)], axis=1)\n",
    "\n",
    "Z = data[:, [idx_dict['Z']]]\n",
    "T = data[:, [idx_dict['T']]]\n",
    "C = data[:, [idx_dict['C']]]\n",
    "Y = data[:, [idx_dict['Y']]]\n",
    "\n",
    "corrs = np.corrcoef(data.transpose())\n",
    "\n",
    "print_corr('C', 'T', corrs, idx_dict)\n",
    "print_corr('C', 'Y', corrs, idx_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "We see what is exactly represented in our graph above: $C_{\\text{SES}}$ is correlated with both $T_{\\text{smoking}}$ and $Y_{\\text{birthweight}}$, so $C_{\\text{SES}}$ is a potential confounder if not included in our analysis. Let's say that it is difficult to observe and quantify $C_{\\text{SES}}$, so we do not have it available to regress against. This is another example of the **omitted variable bias** we saw in the last tutorial.\n",
    "\n",
    "What about $Z_{\\text{taxes}}$? Does it satisfy conditions 1, 2, and 3 of an instrument?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 87
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.774624Z",
     "iopub.status.busy": "2021-05-25T01:25:08.773255Z",
     "iopub.status.idle": "2021-05-25T01:25:08.776512Z",
     "shell.execute_reply": "2021-05-25T01:25:08.776009Z"
    },
    "outputId": "0eee211f-4d87-4ebc-d78a-0f53483a3b3d"
   },
   "outputs": [],
   "source": [
    "#@markdown Execute this cell to see correlations of Z\n",
    "print(\"Condition 2?\")\n",
    "print_corr('Z', 'T', corrs, idx_dict)\n",
    "print(\"Condition 3?\")\n",
    "print_corr('Z', 'C', corrs, idx_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Perfect! We see that $Z_{\\text{taxes}}$ is correlated with $T_{\\text{smoking}}$ (#2) but is uncorrelated with $C_{\\text{SES}}$ (#3).  $Z_\\text{taxes}$ is also observable (#1), so we've satisfied our three criteria for an instrument:\n",
    "\n",
    "1.   $Z_\\text{taxes}$ is observable\n",
    "2.   $Z_\\text{taxes}$ affects $T_{\\text{smoking}}$\n",
    "3.   $Z_\\text{taxes}$ doesn't affect $Y_{\\text{birthweight}}$ except through $T_{\\text{smoking}}$ (ie $Z_\\text{taxes}$ doesn't affect or is affected by $C_\\text{SES}$)"
   ]
  },
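  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Before turning to the IV machinery, it helps to see how far off a regression that ignores $C_{\\text{SES}}$ would be. Regressing $Y$ on $T$ alone has population slope $\\text{Cov}(T, Y) / \\text{Var}(T)$. Substituting $Y = 3000 + C - 2T$ gives $-2 + \\text{Cov}(T, C) / \\text{Var}(T)$, and with the covariance `2 * covar` from the setup cell this is $-2 + (-1)/2 = -2.5$. The sketch below checks this by simulation. It is a minimal illustration rather than part of the original exercises, and the generator seed, the sample size, and the names `z`, `t`, `c`, `y`, and `naive_slope` are assumptions chosen for the example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Population covariance of (Z, T, C), matching 2 * covar in the setup cell\n",
    "cov = 2 * np.array([[1.0, 0.5, 0.0],\n",
    "                    [0.5, 1.0, -0.5],\n",
    "                    [0.0, -0.5, 1.0]])\n",
    "\n",
    "# Seed and sample size are illustrative assumptions, not the tutorial's values\n",
    "rng = np.random.default_rng(0)\n",
    "z, t, c = rng.multivariate_normal([0, 5, 2], cov, size=200_000).T\n",
    "\n",
    "# Structural equation from the text\n",
    "y = 3000 + c - 2 * t\n",
    "\n",
    "# Slope of a regression of Y on T alone: Cov(T, Y) / Var(T)\n",
    "naive_slope = np.cov(t, y)[0, 1] / np.var(t, ddof=1)\n",
    "print(f'naive slope of Y on T: {naive_slope:.2f}')\n",
    "```\n",
    "\n",
    "The printed slope should land near $-2.5$ rather than the true $-2$: with $C_{\\text{SES}}$ omitted, smoking looks more harmful than it is in this simulated system."
   ]
  },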
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Section 1.2: How IV works, at high level\n",
    "\n",
    "The easiest way to imagine IV is that the instrument is **an observable source of \"randomness\"** that affects the treatment. In this way it's similar to the interventions we talked about in Tutorial 1.\n",
    "\n",
    "But how do you actually use the instrument? The key is that we need to extract **the component of the treatment that is due only to the effect of the instrument**. We will call this component $\\hat{T}$.\n",
    "$$\n",
    "\\hat{T}\\leftarrow \\text{The unconfounded component of }T\n",
    "$$\n",
    "Getting $\\hat{T}$ is fairly simple. It is simply the predicted value of $T$ found in a regression that has only the instrument $Z$ as input.\n",
    "\n",
    "Once we have the unconfounded component in hand, getting the causal effect is as easy as regressing the outcome on $\\hat{T}$."
   ]
  },
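  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make the idea of an unconfounded component concrete, here is a short worked sketch. It assumes a linear first stage with slope $\\alpha$ and drops the intercept, which does not change any covariance:\n",
    "\n",
    "$$\n",
    "\\hat{T} = \\alpha Z, \\qquad \\alpha = \\frac{\\text{Cov}(Z, T)}{\\text{Var}(Z)}\n",
    "$$\n",
    "\n",
    "Because $\\hat{T}$ is a rescaled copy of $Z$, it inherits the instrument's lack of correlation with the confounder:\n",
    "\n",
    "$$\n",
    "\\text{Cov}(\\hat{T}, C) = \\alpha \\, \\text{Cov}(Z, C) = 0\n",
    "$$\n",
    "\n",
    "In the tobacco simulation the population values are $\\text{Cov}(Z, T) = 1$ and $\\text{Var}(Z) = 2$ (the data were drawn with covariance `2 * covar`), so $\\alpha = 0.5$. In a finite sample, $\\text{Cov}(Z, C)$ is only approximately zero, so $\\hat{T}$ will be almost, but not exactly, uncorrelated with $C$."
   ]
  },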
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Section 1.3: IV estimation using two-stage least squares\n",
    "\n",
    "The fundamental technique for instrumental variable estimation is **two-stage least squares**. \n",
    "\n",
    "We run two regressions:\n",
    "\n",
    "1. The first stage gets  $\\hat{T}_{\\text{smoking}}$ by regressing $T_{\\text{smoking}}$ on $Z_\\text{taxes}$, fitting the parameter $\\hat{\\alpha}$:\n",
    "\n",
    "$$\n",
    "\\hat{T}_{\\text{smoking}} = \\hat{\\alpha} Z_\\text{taxes}\n",
    "$$\n",
    "\n",
    "2. The second stage then regresses $Y_{\\text{birthweight}}$ on $\\hat{T}_{\\text{smoking}}$ to obtain an estimate $\\hat{\\beta}$ of the causal effect:\n",
    "\n",
    "$$\n",
    "\\hat{Y}_{\\text{birthweight}} = \\hat{\\beta} \\hat{T}_{\\text{smoking}} \n",
    "$$\n",
    "\n",
    "The first stage estimates the **unconfounded component** of $T_{\\text{smoking}}$ (ie, unaffected by the confounder $C_{\\text{SES}}$), as we discussed above. \n",
    "\n",
    "Then, the second stage uses this unconfounded component $\\hat{T}_{\\text{smoking}}$ to estimate the effect of smoking on $\\hat{Y}_{\\text{birthweight}}$. \n",
    "\n",
    "We will explore how all this works in the next two exercises.\n"
   ]
  },
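  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The two stages can be collapsed into a single ratio, which shows why the confounder drops out. The following is a sketch under the linear model of this section, with intercepts omitted because they do not change covariances and with $\\hat{\\alpha} = \\text{Cov}(Z, T) / \\text{Var}(Z)$ taken as the first-stage slope:\n",
    "\n",
    "$$\n",
    "\\hat{\\beta} = \\frac{\\text{Cov}(\\hat{T}, Y)}{\\text{Var}(\\hat{T})} = \\frac{\\hat{\\alpha} \\, \\text{Cov}(Z, Y)}{\\hat{\\alpha}^2 \\, \\text{Var}(Z)} = \\frac{\\text{Cov}(Z, Y)}{\\text{Cov}(Z, T)}\n",
    "$$\n",
    "\n",
    "Substituting $Y_{\\text{birthweight}} = 3000 + C_{\\text{SES}} - 2T_{\\text{smoking}}$ and using $\\text{Cov}(Z, C) = 0$:\n",
    "\n",
    "$$\n",
    "\\text{Cov}(Z, Y) = \\text{Cov}(Z, C) - 2 \\, \\text{Cov}(Z, T) = -2 \\, \\text{Cov}(Z, T) \\quad \\Rightarrow \\quad \\hat{\\beta} = -2\n",
    "$$\n",
    "\n",
    "This holds exactly only for the population covariances. With a finite sample, $\\text{Cov}(Z, C)$ is not exactly zero, so the estimate will be close to $-2$ but not identical to it."
   ]
  },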
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Section 1.3.1: Least squares regression stage 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 517
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.781950Z",
     "iopub.status.busy": "2021-05-25T01:25:08.781359Z",
     "iopub.status.idle": "2021-05-25T01:25:08.828589Z",
     "shell.execute_reply": "2021-05-25T01:25:08.829067Z"
    },
    "outputId": "16c25233-6c85-4dd6-a148-f27cf8d25f10"
   },
   "outputs": [],
   "source": [
    "#@title Video 2: Stage 1\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"4WT0KrySRTg\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Exercise 1: Compute regression stage 1\n",
    "\n",
    "Let's run the regression of $T_{\\text{smoking}}$ on $Z_\\text{taxes}$ to compute $\\hat{T}_{\\text{smoking}}$. We will then check whether our estimate is still confounded with $C_{\\text{SES}}$ by comparing the correlation of $C_{\\text{SES}}$  with $T_{\\text{smoking}}$ vs $\\hat{T}_{\\text{smoking}}$.\n",
    "\n",
    "### Suggestions\n",
    "\n",
    "- use the `LinearRegression()` model, already imported from scikit-learn\n",
    "    - use `fit_intercept=True` as the only parameter setting\n",
    "- be sure to check the ordering of the parameters passed to `LinearRegression.fit()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.835184Z",
     "iopub.status.busy": "2021-05-25T01:25:08.833805Z",
     "iopub.status.idle": "2021-05-25T01:25:08.835852Z",
     "shell.execute_reply": "2021-05-25T01:25:08.836312Z"
    }
   },
   "outputs": [],
   "source": [
    "def fit_first_stage(T, Z):\n",
    "    \"\"\"\n",
    "    Estimates T_hat as the first stage of a two-stage least squares.\n",
    "\n",
    "    Args:\n",
    "        T (np.ndarray): our observed, possibly confounded, treatment of shape (n, 1)\n",
    "        Z (np.ndarray): our observed instruments of shape (n, 1)\n",
    "\n",
    "    Returns\n",
    "        T_hat (np.ndarray): our estimate of the unconfounded portion of T\n",
    "    \"\"\"\n",
    "\n",
    "    ############################################################################\n",
    "    ## Insert your code here to fit the first stage of the 2-stage least squares\n",
    "    ## estimate.\n",
    "    ## Fill out function and remove\n",
    "    raise NotImplementedError('Please complete fit_first_stage function')\n",
    "    ############################################################################\n",
    "\n",
    "    # Initialize linear regression model\n",
    "    stage1 = LinearRegression(...)\n",
    "\n",
    "    # Fit linear regression model\n",
    "    stage1.fit(...)\n",
    "\n",
    "    # Predict T_hat using linear regression model\n",
    "    T_hat = stage1.predict(...)\n",
    "\n",
    "    return T_hat\n",
    "\n",
    "\n",
    "# Uncomment below to test your function\n",
    "# T_hat = fit_first_stage(T, Z)\n",
    "\n",
    "# T_C_corr = np.corrcoef(T.transpose(), C.transpose())[0, 1]\n",
    "# T_hat_C_corr = np.corrcoef(T_hat.transpose(), C.transpose())[0, 1]\n",
    "# print(\"Correlation between T and C: {:.3f}\".format(T_C_corr))\n",
    "# print(\"Correlation between T_hat and C: {:.3f}\".format(T_hat_C_corr))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "colab_type": "text",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.842839Z",
     "iopub.status.busy": "2021-05-25T01:25:08.841980Z",
     "iopub.status.idle": "2021-05-25T01:25:08.846946Z",
     "shell.execute_reply": "2021-05-25T01:25:08.846385Z"
    },
    "outputId": "209e0863-878f-4d19-f11f-33241a9f6919"
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/W3D5_NetworkCausality/solutions/W3D5_Tutorial4_Solution_3985425a.py)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "You should see a correlation between $T$ and $C$ of `-0.483` and between $\\hat{T}$ and $C$ of `0.009`."
   ]
  },
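  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To see why the first stage removes the confounded part of $T$, it can help to write out the covariance. The sketch below assumes a linear first stage with fitted intercept $\\hat{a}$ and slope $\\hat{b}$ (symbols introduced here only for illustration), and it relies on the instrument $Z$ being uncorrelated with the confounder $C$:\n",
    "\n",
    "$$\n",
    "\\hat{T} = \\hat{a} + \\hat{b} Z \\quad \\Rightarrow \\quad \\text{Cov}(\\hat{T}, C) = \\hat{b} \\, \\text{Cov}(Z, C) \\approx 0\n",
    "$$\n",
    "\n",
    "In a finite sample the covariance between $Z$ and $C$ is not exactly zero, which is presumably why the correlation between $\\hat{T}$ and $C$ comes out small rather than exactly zero."
   ]
  },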
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Section 1.3.2: Least squares regression stage 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 517
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.852521Z",
     "iopub.status.busy": "2021-05-25T01:25:08.851975Z",
     "iopub.status.idle": "2021-05-25T01:25:08.897393Z",
     "shell.execute_reply": "2021-05-25T01:25:08.897869Z"
    },
    "outputId": "84e1e3b4-07c4-426d-96df-743e35b49777"
   },
   "outputs": [],
   "source": [
    "#@title Video 3: Stage 2\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"F-_m_Vgv75I\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Exercise 2: Compute the IV estimate\n",
    "\n",
    "Now let's implement the second stage! Complete the `fit_second_stage()` function below. We will again use a linear regression model with an intercept. We will then use the function from Exercise 1 (`fit_first_stage`) and this function to estimate the full two-stage regression model. We will obtain the estimated causal effect of the number of cigarettes ($T$) on birth weight ($Y$).\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.901212Z",
     "iopub.status.busy": "2021-05-25T01:25:08.900481Z",
     "iopub.status.idle": "2021-05-25T01:25:08.905362Z",
     "shell.execute_reply": "2021-05-25T01:25:08.904804Z"
    }
   },
   "outputs": [],
   "source": [
    "def fit_second_stage(T_hat, Y):\n",
    "    \"\"\"\n",
    "    Estimates a scalar causal effect from 2-stage least squares regression using\n",
    "    an instrument.\n",
    "\n",
    "    Args:\n",
    "        T_hat (np.ndarray): the output of the first stage regression\n",
    "        Y (np.ndarray): our observed response (n, 1)\n",
    "\n",
    "    Returns:\n",
    "        beta (float): the estimated causal effect\n",
    "    \"\"\"\n",
    "    ############################################################################\n",
    "    ## Insert your code here to fit the second stage of the 2-stage least squares\n",
    "    ## estimate.\n",
    "    ## Fill out function and remove\n",
    "    raise NotImplementedError('Please complete fit_second_stage function')\n",
    "    ############################################################################\n",
    "\n",
    "    # Initialize linear regression model\n",
    "    stage2 = LinearRegression(...)\n",
    "\n",
    "    # Fit model to data\n",
    "    stage2.fit(...)\n",
    "\n",
    "    return stage2.coef_\n",
    "\n",
    "\n",
    "# Uncomment below to test your function\n",
    "# T_hat = fit_first_stage(T, Z)\n",
    "# beta = fit_second_stage(T_hat, Y)\n",
    "# print(\"Estimated causal effect is: {:.3f}\".format(beta[0, 0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "text",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.913662Z",
     "iopub.status.busy": "2021-05-25T01:25:08.912383Z",
     "iopub.status.idle": "2021-05-25T01:25:08.915560Z",
     "shell.execute_reply": "2021-05-25T01:25:08.915037Z"
    },
    "outputId": "96b05157-e059-4672-bbec-4986f29ace8d"
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/W3D5_NetworkCausality/solutions/W3D5_Tutorial4_Solution_b764d30b.py)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "You should obtain an estimated causal effect of `-1.984`. This is quite close to the true causal effect of $-2$!"
   ]
  },
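  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "For a single instrument, the two stages can be collapsed into one expression. The derivation below is a sketch that assumes both stages are ordinary least squares with an intercept, as in Exercises 1 and 2; $\\hat{a}$ and $\\hat{b}$ (the first-stage intercept and slope) and $\\hat{\\beta}$ (the second-stage slope) are notation introduced here:\n",
    "\n",
    "$$\n",
    "\\hat{b} = \\frac{\\text{Cov}(Z, T)}{\\text{Var}(Z)}, \\qquad \\hat{T} = \\hat{a} + \\hat{b} Z\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\hat{\\beta} = \\frac{\\text{Cov}(\\hat{T}, Y)}{\\text{Var}(\\hat{T})} = \\frac{\\hat{b} \\, \\text{Cov}(Z, Y)}{\\hat{b}^2 \\, \\text{Var}(Z)} = \\frac{\\text{Cov}(Z, Y)}{\\text{Cov}(Z, T)}\n",
    "$$\n",
    "\n",
    "So the estimate is the effect of $Z$ on $Y$ rescaled by the effect of $Z$ on $T$. The denominator also hints at why a weak instrument (small $\\text{Cov}(Z, T)$) makes the estimate unstable."
   ]
  },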
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Section 2: IVs in our simulated neural system\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 517
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.920953Z",
     "iopub.status.busy": "2021-05-25T01:25:08.920366Z",
     "iopub.status.idle": "2021-05-25T01:25:08.967720Z",
     "shell.execute_reply": "2021-05-25T01:25:08.968194Z"
    },
    "outputId": "d34ab471-4b8f-48b3-fcd6-5eec8ac0a9b5"
   },
   "outputs": [],
   "source": [
    "#@title Video 4: IVs in simulated neural systems\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"b6a3Mrefk44\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Now, say we have the neural system we have been simulating, except with an additional variable $\\vec{z}$. This will be our instrumental variable. \n",
    "\n",
    "We treat $\\vec{z}$ as a source of noise in the dynamics of our neurons:\n",
    "\n",
    "$$\n",
    "\\vec{x}_{t+1} = \\sigma(A\\vec{x}_t + \\eta \\vec{z}_{t+1} + \\epsilon_t)\n",
    "$$\n",
    "\n",
    "- $\\eta$ is what we'll call the \"strength\" of our IV\n",
    "- $\\vec{z}_t$ is a random binary variable, $\\vec{z}_t \\sim Bernoulli(0.5)$\n",
    "\n",
    "Remember that for each neuron $i$, we are trying to figure out whether $i$ is connected to (causally affects) the other neurons in our system *at the next time step*. So for timestep $t$, we want to determine whether $\\vec{x}_{i,t}$ affects all the other neurons at $\\vec{x}_{t+1}$. For a given neuron $i$, $\\vec{z}_{i,t}$ satistfies the 3 criteria for a valid instrument. \n",
    "\n",
    "\n",
    "**What could $z$ be, biologically?**\n",
    "\n",
    "Imagine $z$ to be some injected current through an *in vivo* patch clamp. It affects each neuron individually, and only affects dynamics through that neuron.\n",
    "\n",
    "The cool thing about IV is that you don't have to control $z$ yourself - it can be observed. So if you mess up your wiring and accidentally connect the injected voltage to an AM radio, no worries. As long as you can observe the signal the method will work."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "![neuron_iv.png]()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Exercise 3: Simulate a system with IV\n",
    "\n",
    "Here we'll modify the function that simulates the neural system, but this time make the update rule include the effect of the instrumental variable $z$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.977086Z",
     "iopub.status.busy": "2021-05-25T01:25:08.975784Z",
     "iopub.status.idle": "2021-05-25T01:25:08.977719Z",
     "shell.execute_reply": "2021-05-25T01:25:08.978174Z"
    }
   },
   "outputs": [],
   "source": [
    "def simulate_neurons_iv(n_neurons, timesteps, eta, random_state=42):\n",
    "    \"\"\"\n",
    "    Simulates a dynamical system for the specified number of neurons and timesteps.\n",
    "\n",
    "    Args:\n",
    "        n_neurons (int): the number of neurons in our system.\n",
    "        timesteps (int): the number of timesteps to simulate our system.\n",
    "        eta (float): the strength of the instrument\n",
    "        random_state (int): seed for reproducibility\n",
    "\n",
    "    Returns:\n",
    "        The tuple (A,X,Z) of the connectivity matrix, simulated system, and instruments.\n",
    "        - A has shape (n_neurons, n_neurons)\n",
    "        - X has shape (n_neurons, timesteps)\n",
    "        - Z has shape (n_neurons, timesteps)\n",
    "    \"\"\"\n",
    "    np.random.seed(random_state)\n",
    "    A = create_connectivity(n_neurons, random_state)\n",
    "\n",
    "    X = np.zeros((n_neurons, timesteps))\n",
    "    Z = np.random.choice([0, 1], size=(n_neurons, timesteps))\n",
    "    for t in range(timesteps - 1):\n",
    "\n",
    "      ############################################################################\n",
    "      ## Insert your code here to adjust the update rule to include the\n",
    "      ## instrumental variable.\n",
    "      ##  We've already created Z for you. (We need to return it to regress on it).\n",
    "      ##  Your task is to slice it appropriately. Don't forget eta.\n",
    "      ## Fill out function and remove\n",
    "      raise NotImplementedError('Complete simulate_neurons_iv function')\n",
    "      ############################################################################\n",
    "\n",
    "      IV_on_this_timestep = ...\n",
    "\n",
    "      X[:, t + 1] = sigmoid(A.dot(X[:, t]) + IV_on_this_timestep + np.random.multivariate_normal(np.zeros(n_neurons), np.eye(n_neurons)))\n",
    "\n",
    "    return A, X, Z\n",
    "\n",
    "# Parameters\n",
    "timesteps = 5000  # Simulate for 5000 timesteps.\n",
    "n_neurons = 100  # the size of our system\n",
    "eta = 2  # the strength of our instrument, higher is stronger\n",
    "\n",
    "\n",
    "# Uncomment below to test your function\n",
    "\n",
    "# Simulate our dynamical system for the given amount of time\n",
    "# A, X, Z = simulate_neurons_iv(n_neurons, timesteps, eta)\n",
    "# plot_neural_activity(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 431
    },
    "colab_type": "text",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:08.989590Z",
     "iopub.status.busy": "2021-05-25T01:25:08.988978Z",
     "iopub.status.idle": "2021-05-25T01:25:13.144894Z",
     "shell.execute_reply": "2021-05-25T01:25:13.144371Z"
    },
    "outputId": "c39ed351-d155-4897-beda-ae33de41aca4"
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/W3D5_NetworkCausality/solutions/W3D5_Tutorial4_Solution_bb98112d.py)\n",
    "\n",
    "*Example output:*\n",
    "\n",
    "<img alt='Solution hint' align='left' width=558 height=414 src=https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/tutorials/W3D5_NetworkCausality/static/W3D5_Tutorial4_Solution_bb98112d_0.png>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Section 2.1: Estimate IV for simulated neural system"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Since you just implemented two-stage least squares, we've provided the network implementation for you, with the function `get_iv_estimate_network()`. Now, let's see how our IV estimates do in recovering the connectivity matrix."
   ]
  },
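  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The inverse sigmoid (logit) step in the function below follows from the update rule of Section 2. As a sketch, applying $\\sigma^{-1}$ to both sides of that rule gives a model that is linear in $\\vec{x}_t$:\n",
    "\n",
    "$$\n",
    "\\sigma^{-1}(\\vec{x}_{t+1}) = A\\vec{x}_t + \\eta \\vec{z}_{t+1} + \\epsilon_t\n",
    "$$\n",
    "\n",
    "The second stage therefore regresses $\\sigma^{-1}(\\vec{x}_{t+1})$ on the first-stage prediction of $\\vec{x}_t$, which is obtained from $\\vec{z}_t$, and the fitted coefficients are read off as the estimate of $A$. Because $\\vec{z}_{t+1}$ is drawn independently of $\\vec{z}_t$, it acts as additional noise in this regression rather than as a confounder."
   ]
  },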
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:13.153262Z",
     "iopub.status.busy": "2021-05-25T01:25:13.151980Z",
     "iopub.status.idle": "2021-05-25T01:25:13.153888Z",
     "shell.execute_reply": "2021-05-25T01:25:13.154343Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_iv_estimate_network(X, Z):\n",
    "    \"\"\"\n",
    "    Estimates the connectivity matrix from 2-stage least squares regression\n",
    "    using an instrument.\n",
    "\n",
    "    Args:\n",
    "        X (np.ndarray): our simulated system of shape (n_neurons, timesteps)\n",
    "        Z (np.ndarray): our observed instruments of shape (n_neurons, timesteps)\n",
    "\n",
    "    Returns:\n",
    "\n",
    "        V (np.ndarray): the estimated connectivity matrix\n",
    "    \"\"\"\n",
    "    n_neurons = X.shape[0]\n",
    "    Y = X[:, 1:].transpose()\n",
    "\n",
    "    # apply inverse sigmoid transformation\n",
    "    Y = logit(Y)\n",
    "\n",
    "    # Stage 1: regress X on Z\n",
    "    stage1 = MultiOutputRegressor(LinearRegression(fit_intercept=True), n_jobs=-1)\n",
    "    stage1.fit(Z[:, :-1].transpose(), X[:, :-1].transpose())\n",
    "    X_hat = stage1.predict(Z[:, :-1].transpose())\n",
    "\n",
    "    # Stage 2: regress Y on X_hatI\n",
    "    stage2 = MultiOutputRegressor(LinearRegression(fit_intercept=True), n_jobs=-1)\n",
    "    stage2.fit(X_hat, Y)\n",
    "\n",
    "    # Get estimated effects\n",
    "    V = np.zeros((n_neurons, n_neurons))\n",
    "    for i, estimator in enumerate(stage2.estimators_):\n",
    "        V[i, :] = estimator.coef_\n",
    "\n",
    "    return V"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Now let's see how well it works in our system."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 352
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:13.161806Z",
     "iopub.status.busy": "2021-05-25T01:25:13.161152Z",
     "iopub.status.idle": "2021-05-25T01:25:15.622444Z",
     "shell.execute_reply": "2021-05-25T01:25:15.622976Z"
    },
    "outputId": "f23ccf76-9808-483d-e6b2-6d86c33f9ab8"
   },
   "outputs": [],
   "source": [
    "#@markdown Execute this cell to visualize IV estimated connectivity matrix\n",
    "n_neurons = 6\n",
    "timesteps = 10000\n",
    "random_state = 42\n",
    "eta = 2\n",
    "\n",
    "A, X, Z = simulate_neurons_iv(n_neurons, timesteps, eta, random_state)\n",
    "V = get_iv_estimate_network(X, Z)\n",
    "print(\"IV estimated correlation: {:.3f}\".format(np.corrcoef(A.flatten(), V.flatten())[1, 0]))\n",
    "\n",
    "fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
    "\n",
    "im = axs[0].imshow(A, cmap=\"coolwarm\")\n",
    "fig.colorbar(im, ax=axs[0],fraction=0.046, pad=0.04)\n",
    "axs[0].title.set_text(\"True connectivity matrix\")\n",
    "axs[0].set(xlabel='Connectivity from', ylabel='Connectivity to')\n",
    "\n",
    "im = axs[1].imshow(V, cmap=\"coolwarm\")\n",
    "fig.colorbar(im, ax=axs[1],fraction=0.046, pad=0.04)\n",
    "axs[1].title.set_text(\"IV estimated connectivity matrix\")\n",
    "axs[1].set(xlabel='Connectivity from')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The IV estimates seem to perform pretty well! In the next section, we will see how they behave in the face of omitted variable bias."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Section 3: IVs and omitted variable bias \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 517
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:15.629377Z",
     "iopub.status.busy": "2021-05-25T01:25:15.628857Z",
     "iopub.status.idle": "2021-05-25T01:25:15.678030Z",
     "shell.execute_reply": "2021-05-25T01:25:15.677554Z"
    },
    "outputId": "fb738e85-3553-422b-c177-16d9fa91cb04"
   },
   "outputs": [],
   "source": [
    "#@title Video 5: IV vs regression\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"zceWyoQn09s\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Interactive Demo: Estimating connectivity with IV vs regression on a subset of observed neurons\n",
    "\n",
    "Change the ratio of observed neurons and look at the impact on the quality of connectivity estimation using IV vs regression. Which method does better with fewer observed neurons?\n",
    "\n",
    "**NOTE:** this simulation will take about a minute to run!\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 363,
     "referenced_widgets": [
      "f47f8abc88714432a420356abe6b621a",
      "b92d7557f0cd4bd0b6bec001fd20186b",
      "c94e8aa5c7184899881d842b4ca9c5f5",
      "19a60e1994f94846911c962401b29643",
      "b55b6f949bec411e88e60324512115b0",
      "85159b65f825419ba2cb703aae7fe9e0",
      "5ee12cb064c447e19dd87386275ae00a"
     ]
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:15.697009Z",
     "iopub.status.busy": "2021-05-25T01:25:15.696386Z",
     "iopub.status.idle": "2021-05-25T01:25:19.436169Z",
     "shell.execute_reply": "2021-05-25T01:25:19.435675Z"
    },
    "outputId": "10456272-28c2-43fb-f45f-2650133f6351"
   },
   "outputs": [],
   "source": [
    "#@title\n",
    "#@markdown Execute this cell to enable demo\n",
    "n_neurons = 30\n",
    "timesteps = 20000\n",
    "random_state = 42\n",
    "eta = 2\n",
    "A, X, Z = simulate_neurons_iv(n_neurons, timesteps, eta, random_state)\n",
    "\n",
    "\n",
    "reg_args = {\n",
    "    \"fit_intercept\": False,\n",
    "    \"alpha\": 0.001\n",
    "}\n",
    "\n",
    "def get_regression_estimate_full_connectivity(X):\n",
    "    \"\"\"\n",
    "    Estimates the connectivity matrix using lasso regression.\n",
    "\n",
    "    Args:\n",
    "        X (np.ndarray): our simulated system of shape (n_neurons, timesteps)\n",
    "        neuron_idx (int): optionally provide a neuron idx to compute connectivity for\n",
    "    Returns:\n",
    "        V (np.ndarray): estimated connectivity matrix of shape (n_neurons, n_neurons).\n",
    "                        if neuron_idx is specified, V is of shape (n_neurons,).\n",
    "    \"\"\"\n",
    "    n_neurons = X.shape[0]\n",
    "\n",
    "    # Extract Y and W as defined above\n",
    "    W = X[:, :-1].transpose()\n",
    "    Y = X[:, 1:].transpose()\n",
    "\n",
    "    # apply inverse sigmoid transformation\n",
    "    Y = logit(Y)\n",
    "\n",
    "    # fit multioutput regression\n",
    "    reg = MultiOutputRegressor(Lasso(fit_intercept=False, alpha=0.01, max_iter=200), n_jobs=-1)\n",
    "    reg.fit(W, Y)\n",
    "\n",
    "    V = np.zeros((n_neurons, n_neurons))\n",
    "    for i, estimator in enumerate(reg.estimators_):\n",
    "        V[i, :] = estimator.coef_\n",
    "\n",
    "    return V\n",
    "\n",
    "\n",
    "def get_regression_corr_full_connectivity(n_neurons, A, X, observed_ratio, regression_args):\n",
    "    \"\"\"\n",
    "    A wrapper function for our correlation calculations between A and the V estimated\n",
    "    from regression.\n",
    "\n",
    "    Args:\n",
    "        n_neurons (int): number of neurons\n",
    "        A (np.ndarray): connectivity matrix\n",
    "        X (np.ndarray): dynamical system\n",
    "        observed_ratio (float): the proportion of n_neurons observed, must be betweem 0 and 1.\n",
    "        regression_args (dict): dictionary of lasso regression arguments and hyperparameters\n",
    "\n",
    "    Returns:\n",
    "        A single float correlation value representing the similarity between A and R\n",
    "    \"\"\"\n",
    "    assert (observed_ratio > 0) and (observed_ratio <= 1)\n",
    "\n",
    "    sel_idx = np.clip(int(n_neurons*observed_ratio), 1, n_neurons)\n",
    "\n",
    "    sel_X = X[:sel_idx, :]\n",
    "    sel_A = A[:sel_idx, :sel_idx]\n",
    "\n",
    "    sel_V = get_regression_estimate_full_connectivity(sel_X)\n",
    "    return np.corrcoef(sel_A.flatten(), sel_V.flatten())[1,0], sel_V\n",
    "\n",
    "\n",
    "@widgets.interact\n",
    "def plot_observed(ratio=[0.2, 0.4, 0.6, 0.8, 1.0]):\n",
    "  fig, axs = plt.subplots(1, 3, figsize=(15, 5))\n",
    "  sel_idx = int(ratio * n_neurons)\n",
    "  n_observed = sel_idx\n",
    "  offset = np.zeros((n_neurons, n_neurons))\n",
    "  offset[:sel_idx, :sel_idx] =  1 + A[:sel_idx, :sel_idx]\n",
    "  im = axs[0].imshow(offset, cmap=\"coolwarm\", vmin=0, vmax=A.max() + 1)\n",
    "  axs[0].title.set_text(\"True connectivity\")\n",
    "  axs[0].set_xlabel(\"Connectivity to\")\n",
    "  axs[0].set_ylabel(\"Connectivity from\")\n",
    "  plt.colorbar(im, ax=axs[0],fraction=0.046, pad=0.04)\n",
    "\n",
    "  sel_A = A[:sel_idx, :sel_idx]\n",
    "  sel_X = X[:sel_idx, :]\n",
    "  sel_Z = Z[:sel_idx, :]\n",
    "\n",
    "  V = get_iv_estimate_network(sel_X, sel_Z)\n",
    "  iv_corr = np.corrcoef(sel_A.flatten(), V.flatten())[1, 0]\n",
    "\n",
    "  big_V = np.zeros(A.shape)\n",
    "  big_V[:sel_idx, :sel_idx] =  1 + V\n",
    "\n",
    "  im = axs[1].imshow(big_V, cmap=\"coolwarm\", vmin=0, vmax=A.max() + 1)\n",
    "  plt.colorbar(im, ax=axs[1], fraction=0.046, pad=0.04)\n",
    "  c = 'w' if n_observed < (n_neurons - 3) else 'k'\n",
    "  axs[1].text(0,n_observed + 2, \"Correlation : {:.2f}\".format(iv_corr), color=c, size=15)\n",
    "  axs[1].axis(\"off\")\n",
    "\n",
    "\n",
    "  reg_corr, R = get_regression_corr_full_connectivity(n_neurons,\n",
    "                                  A,\n",
    "                                  X,\n",
    "                                  ratio,\n",
    "                                  reg_args)\n",
    "\n",
    "\n",
    "  big_R = np.zeros(A.shape)\n",
    "  big_R[:sel_idx, :sel_idx] =  1 + R\n",
    "\n",
    "  im = axs[2].imshow(big_R, cmap=\"coolwarm\", vmin=0, vmax=A.max() + 1)\n",
    "  plt.colorbar(im, ax=axs[2], fraction=0.046, pad=0.04)\n",
    "  c = 'w' if n_observed<(n_neurons-3) else 'k'\n",
    "  axs[1].title.set_text(\"Estimated connectivity (IV)\")\n",
    "  axs[1].set_xlabel(\"Connectivity to\")\n",
    "  axs[1].set_ylabel(\"Connectivity from\")\n",
    "\n",
    "  axs[2].text(0, n_observed + 2,\"Correlation : {:.2f}\".format(reg_corr), color=c, size=15)\n",
    "  axs[2].axis(\"off\")\n",
    "  axs[2].title.set_text(\"Estimated connectivity (regression)\")\n",
    "  axs[2].set_xlabel(\"Connectivity to\")\n",
    "  axs[2].set_ylabel(\"Connectivity from\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "We can also visualize the performance of regression and IV as a function of the observed neuron ratio below.\n",
    "\n",
    "**Note** that this code takes about a minute to run!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 518
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:19.444330Z",
     "iopub.status.busy": "2021-05-25T01:25:19.443098Z",
     "iopub.status.idle": "2021-05-25T01:25:59.325381Z",
     "shell.execute_reply": "2021-05-25T01:25:59.324895Z"
    },
    "outputId": "37274be5-e39f-40f9-c0d9-bd6cddb41ad8"
   },
   "outputs": [],
   "source": [
    "#@markdown Execute this cell to visualize connectivity estimation performance\n",
    "n_neurons = 40  # the size of the system\n",
    "timesteps = 20000\n",
    "random_state = 42\n",
    "eta = 2  # the strength of our instrument\n",
    "\n",
    "A, X, Z = simulate_neurons_iv(n_neurons, timesteps, eta, random_state)\n",
    "\n",
    "observed_ratio = [1, 0.8, 0.6, 0.4, 0.2]\n",
    "\n",
    "compare_iv_estimate_to_regression(observed_ratio)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "We see that IVs handle omitted variable bias (when the instrument is strong and we have enough data).\n",
    "\n",
    "### The costs of IV analysis\n",
    "\n",
    "- we need to find an appropriate and valid instrument\n",
    "- Because of the 2-stage estimation process, we need strong instruments or else our standard errors will be large"
   ]
  },
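  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make the second cost concrete, a standard large-sample approximation for the single-instrument case is sketched below. It assumes homoskedastic noise with variance $\\sigma_\\epsilon^2$ and $n$ samples, and $\\rho_{Z,T}$ denotes the correlation between instrument and treatment. These symbols are introduced here for illustration and are not used elsewhere in the tutorial:\n",
    "\n",
    "$$\n",
    "\\text{Var}(\\hat{\\beta}_{IV}) \\approx \\frac{\\sigma_\\epsilon^2}{n \\, \\text{Var}(T) \\, \\rho_{Z,T}^2}\n",
    "$$\n",
    "\n",
    "Compared with ordinary regression, the variance is inflated by a factor of $1 / \\rho_{Z,T}^2$, so halving the instrument's correlation with the treatment roughly quadruples the variance (and doubles the standard error)."
   ]
  },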
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Discussion questions\n",
    "\n",
    "\n",
    "*   Think back to your most recent work. Can you create a causal diagram of the fundamental question? Are there sources of bias (omitted variables or otherwise) that might be a threat to causal validity?\n",
    "*   Can you think of any possibilities for instrumental variables? What sources of observed randomness could studies in your field leverage in identifying causal effects?\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Summary\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 517
    },
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:59.332108Z",
     "iopub.status.busy": "2021-05-25T01:25:59.331530Z",
     "iopub.status.idle": "2021-05-25T01:25:59.378940Z",
     "shell.execute_reply": "2021-05-25T01:25:59.379485Z"
    },
    "outputId": "b1284046-80da-422b-dd41-6edec4f5727a"
   },
   "outputs": [],
   "source": [
    "#@title Video 6: Summary\n",
    "# Insert the ID of the corresponding youtube video\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"1qxW8CPW77U\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtu.be/\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "In this tutorial, we:\n",
    "\n",
    "* Explored instrumental variables and how we can use them for causality estimates\n",
    "* Compared IV estimates to regression estimates"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "# Appendix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## (Bonus) Exercise: Exploring instrument strength\n",
    "\n",
    "Explore how the strength of the instrument $\\eta$ affects the quality of estimates with instrumental variables. \n",
    "**bold text**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:59.387391Z",
     "iopub.status.busy": "2021-05-25T01:25:59.386163Z",
     "iopub.status.idle": "2021-05-25T01:25:59.388037Z",
     "shell.execute_reply": "2021-05-25T01:25:59.388462Z"
    }
   },
   "outputs": [],
   "source": [
    "def instrument_strength_effect(etas, n_neurons, timesteps, n_trials):\n",
    "  \"\"\" Compute IV estimation performance for different instrument strengths\n",
    "\n",
    "  Args:\n",
    "    etas (list): different instrument strengths to compare\n",
    "    n_neurons (int): number of neurons in simulation\n",
    "    timesteps (int): number of timesteps in simulation\n",
    "    n_trials (int): number of trials to compute\n",
    "\n",
    "  Returns:\n",
    "    ndarray: n_trials x len(etas) array where each element is the correlation\n",
    "        between true and estimated connectivity matries for that trial and\n",
    "        instrument strength\n",
    "  \"\"\"\n",
    "\n",
    "  # Initialize corr array\n",
    "  corr_data = np.zeros((n_trials, len(etas)))\n",
    "\n",
    "  # Loop over trials\n",
    "  for trial in range(n_trials):\n",
    "      print(\"Trial {} of {}\".format(trial + 1, n_trials))\n",
    "\n",
    "      # Loop over instrument strenghs\n",
    "      for j, eta in enumerate(etas):\n",
    "          ########################################################################\n",
    "          ## TODO: Simulate system with a given instrument strength, get IV estimate,\n",
    "          ## and compute correlation\n",
    "          # Fill out function and remove\n",
    "          raise NotImplementedError('Student exercise: complete instrument_strength_effect')\n",
    "          ########################################################################\n",
    "\n",
    "          # Simulate system\n",
    "          A, X, Z = simulate_neurons_iv(...)\n",
    "\n",
    "          # Compute IV estimate\n",
    "          iv_V = get_iv_estimate_network(...)\n",
    "\n",
    "          # Compute correlation\n",
    "          corr_data[trial, j] =  np.corrcoef(A.flatten(), iv_V.flatten())[1, 0]\n",
    "\n",
    "  return corr_data\n",
    "\n",
    "# Parameters of system\n",
    "n_neurons = 20\n",
    "timesteps = 10000\n",
    "n_trials = 3\n",
    "etas = [2, 1, 0.5, 0.25, 0.12]  # instrument strengths to search over\n",
    "\n",
    "# Uncomment below to test your function\n",
    "#corr_data = instrument_strength_effect(etas, n_neurons, timesteps, n_trials)\n",
    "#plot_performance_vs_eta(etas, corr_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "colab_type": "text",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:25:59.395587Z",
     "iopub.status.busy": "2021-05-25T01:25:59.395015Z",
     "iopub.status.idle": "2021-05-25T01:26:24.890590Z",
     "shell.execute_reply": "2021-05-25T01:26:24.890110Z"
    },
    "outputId": "5c85039d-1cb4-4729-92b6-44950e40a849"
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/W3D5_NetworkCausality/solutions/W3D5_Tutorial4_Solution_8e177c98.py)\n",
    "\n",
    "*Example output:*\n",
    "\n",
    "<img alt='Solution hint' align='left' width=560 height=416 src=https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/tutorials/W3D5_NetworkCausality/static/W3D5_Tutorial4_Solution_8e177c98_3.png>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "---\n",
    "\n",
    "# (Bonus) Section 4: Granger Causality\n",
    "\n",
    "**Please revisit this section AFTER you complete tutorial notebook 4**, if you have time.\n",
    "\n",
    "Another potential solution to temporal causation that we might consider: [*Granger Causality*](https://en.wikipedia.org/wiki/Granger_causality).\n",
    "\n",
    "But, like the simultaneous fitting we explored in Tutorial 3, this method still fails in the presence of unobserved variables.\n",
    "\n",
    "We are testing whether a time series $X$ Granger-causes a time series $Y$ through a hypothesis test:\n",
    "\n",
    "- the null hypothesis $H_0$: lagged values of $X$ do not help predict values of $Y$ \n",
    "\n",
    "- the alternative hypothesis $H_a$: lagged values of $X$ **do** help predict values of $Y$ \n",
    "\n",
    "Mechanically, this is accomplished by fitting autoregressive models for $y_{t}$. We fail to reject the hypothesis if none of the $x_{t-k}$ terms are retained as significant in the regression. For simplicity, we will consider only one time lag. So, we have:\n",
    "\n",
    "$$\n",
    "H_0: y_t = a_0 + a_1 y_{t-1} +\\epsilon_t\n",
    "$$\n",
    "\n",
    "$$\n",
    "H_a: y_t = a_0 + a_1 y_{t-1} + b_1 x_{t-1} +\\epsilon_t\n",
    "$$"
   ]
  },
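  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The comparison between $H_0$ and $H_a$ above can also be carried out by hand. The following is a minimal sketch, not the tutorial's own method: the helper name `lag1_granger_f_test` and the toy data are assumptions made for illustration, and it uses an F-test on the residual sums of squares of the two regressions rather than the likelihood-ratio test used in the exercise.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "\n",
    "def lag1_granger_f_test(x, y):\n",
    "    # F-test of H_0 (y_t from y_{t-1} only) against H_a (adds x_{t-1})\n",
    "    y_t, y_lag, x_lag = y[1:], y[:-1], x[:-1]\n",
    "    ones = np.ones_like(y_lag)\n",
    "\n",
    "    # Restricted model (H_0): y_t = a_0 + a_1 * y_{t-1}\n",
    "    design_0 = np.column_stack([ones, y_lag])\n",
    "    # Unrestricted model (H_a): y_t = a_0 + a_1 * y_{t-1} + b_1 * x_{t-1}\n",
    "    design_a = np.column_stack([ones, y_lag, x_lag])\n",
    "\n",
    "    def rss(design):\n",
    "        # Residual sum of squares of the least-squares fit\n",
    "        coef = np.linalg.lstsq(design, y_t, rcond=None)[0]\n",
    "        return np.sum((y_t - design @ coef) ** 2)\n",
    "\n",
    "    rss_0, rss_a = rss(design_0), rss(design_a)\n",
    "    dof = len(y_t) - design_a.shape[1]\n",
    "    f_stat = (rss_0 - rss_a) / (rss_a / dof)  # one added parameter (b_1)\n",
    "    p_val = stats.f.sf(f_stat, 1, dof)\n",
    "    return f_stat, p_val\n",
    "\n",
    "\n",
    "# Toy data (hypothetical): x drives y with a one-step delay, but not vice versa\n",
    "rng = np.random.default_rng(0)\n",
    "n = 2000\n",
    "x = rng.normal(size=n)\n",
    "y = np.zeros(n)\n",
    "for t in range(1, n):\n",
    "    y[t] = 0.5 * y[t - 1] + 0.8 * x[t - 1] + rng.normal()\n",
    "\n",
    "f_xy, p_xy = lag1_granger_f_test(x, y)  # does x Granger-cause y?\n",
    "f_yx, p_yx = lag1_granger_f_test(y, x)  # does y Granger-cause x?\n",
    "print(f'x -> y: F = {f_xy:.1f}, p = {p_xy:.3g}')\n",
    "print(f'y -> x: F = {f_yx:.1f}, p = {p_yx:.3g}')\n",
    "```\n",
    "\n",
    "In this toy system we would expect a very small p-value for the direction $x \\rightarrow y$, while the reverse direction carries no true effect, so any rejection there would be a false positive."
   ]
  },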
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:26:24.894622Z",
     "iopub.status.busy": "2021-05-25T01:26:24.894075Z",
     "iopub.status.idle": "2021-05-25T01:26:33.409277Z",
     "shell.execute_reply": "2021-05-25T01:26:33.408560Z"
    },
    "outputId": "8d88a260-2847-43e2-bc31-c2994f7820f3"
   },
   "outputs": [],
   "source": [
    "#@markdown Execute this cell to get custom imports from stats models\n",
    "# we need custom imports from stats models\n",
    "!pip install statsmodels\n",
    "from statsmodels.tsa.stattools import grangercausalitytests"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Granger causality in small systems\n",
    "\n",
    "We will first evaluate Granger causality in a small system.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "\n",
    "### (Bonus) Exercise: Evaluate Granger causality\n",
    "\n",
    "Complete the following definition to evaluate the Granger causality between our neurons. Then run the cells below to evaluate how well it works. You will use the `grangercausalitytests()` function already imported from statsmodels. We will then check whether a neuron in a small system Granger-causes the others.\n",
    "\n",
    "### Suggestions\n",
    "\n",
    "- use `help()` to check the function signature for `grangercausalitytests()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:26:33.417861Z",
     "iopub.status.busy": "2021-05-25T01:26:33.416686Z",
     "iopub.status.idle": "2021-05-25T01:26:33.923087Z",
     "shell.execute_reply": "2021-05-25T01:26:33.922206Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_granger_causality(X, selected_neuron, alpha=0.05):\n",
    "    \"\"\"\n",
    "    Estimates the lag-1 granger causality of the given neuron on the other neurons in the system.\n",
    "\n",
    "    Args:\n",
    "        X (np.ndarray): the matrix holding our dynamical system of shape (n_neurons, timesteps)\n",
    "        selected_neuron (int): the index of the neuron we want to estimate granger causality for\n",
    "        alpha (float): Bonferroni multiple comparisons correction\n",
    "\n",
    "    Returns:\n",
    "        A tuple (reject_null, p_vals)\n",
    "        reject_null (list): a binary list of length n_neurons whether the null was\n",
    "            rejected for the selected neuron granger causing the other neurons\n",
    "        p_vals (list): a list of the p-values for the corresponding Granger causality tests\n",
    "    \"\"\"\n",
    "    n_neurons = X.shape[0]\n",
    "    max_lag = 1\n",
    "\n",
    "    reject_null = []\n",
    "    p_vals = []\n",
    "\n",
    "    for target_neuron in range(n_neurons):\n",
    "        ts_data = X[[target_neuron, selected_neuron], :].transpose()\n",
    "\n",
    "        ########################################################################\n",
    "        ## Insert your code here to run Granger causality tests.\n",
    "        ##\n",
    "        ## Function Hints:\n",
    "        ##     Pass the ts_data defined above as the first argument\n",
    "        ##     Granger causality -> grangercausalitytests\n",
    "        ## Fill out this function and then remove\n",
    "        raise NotImplementedError('Student exercise: complete get_granger_causality function')\n",
    "        ########################################################################\n",
    "        res = grangercausalitytests(...)\n",
    "\n",
    "        # Gets the p-value for the log-ratio test\n",
    "        pval = res[1][0]['lrtest'][1]\n",
    "\n",
    "        p_vals.append(pval)\n",
    "        reject_null.append(int(pval < alpha))\n",
    "\n",
    "    return reject_null, p_vals\n",
    "\n",
    "\n",
    "# Set up small system\n",
    "n_neurons = 6\n",
    "timesteps = 5000\n",
    "random_state = 42\n",
    "selected_neuron = 1\n",
    "\n",
    "A = create_connectivity(n_neurons, random_state)\n",
    "X = simulate_neurons(A, timesteps, random_state)\n",
    "\n",
    "# Uncomment below to test your function\n",
    "# reject_null, p_vals = get_granger_causality(X, selected_neuron)\n",
    "\n",
    "# compare_granger_connectivity(A, reject_null, selected_neuron)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {},
    "colab_type": "text",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:26:33.930367Z",
     "iopub.status.busy": "2021-05-25T01:26:33.929841Z",
     "iopub.status.idle": "2021-05-25T01:26:35.014190Z",
     "shell.execute_reply": "2021-05-25T01:26:35.013716Z"
    },
    "outputId": "2fa4881f-37b2-4f00-da0a-348e02ae51ec"
   },
   "source": [
    "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/master//tutorials/W3D5_NetworkCausality/solutions/W3D5_Tutorial4_Solution_f1e1f75e.py)\n",
    "\n",
    "*Example output:*\n",
    "\n",
    "<img alt='Solution hint' align='left' width=693 height=341 src=https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/tutorials/W3D5_NetworkCausality/static/W3D5_Tutorial4_Solution_f1e1f75e_0.png>\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Looks good! Let's also check the correlation between Granger estimates and the true connectivity.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:26:35.019550Z",
     "iopub.status.busy": "2021-05-25T01:26:35.018956Z",
     "iopub.status.idle": "2021-05-25T01:26:35.023289Z",
     "shell.execute_reply": "2021-05-25T01:26:35.022804Z"
    },
    "outputId": "f5d32be3-f40a-4f03-c663-72975f3ccca7"
   },
   "outputs": [],
   "source": [
    "np.corrcoef(A[:,selected_neuron], np.array(reject_null))[1, 0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "When we have a small system, we correctly identify the causality of neuron 1."
   ]
  },
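  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Because `get_granger_causality` runs one hypothesis test per target neuron, some false positives are expected at a fixed `alpha`. One common remedy is a Bonferroni correction, which divides the threshold by the number of tests. The following is a minimal sketch: the helper `bonferroni_reject` and the example p-values are hypothetical and are not part of the tutorial code.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def bonferroni_reject(p_vals, alpha=0.05):\n",
    "    # Compare each p-value against alpha divided by the number of tests\n",
    "    p_vals = np.asarray(p_vals, dtype=float)\n",
    "    threshold = alpha / len(p_vals)\n",
    "    return (p_vals < threshold).astype(int), threshold\n",
    "\n",
    "\n",
    "# Hypothetical p-values standing in for the output of get_granger_causality\n",
    "example_p_vals = [0.0001, 0.03, 0.2, 0.004, 0.6, 0.049]\n",
    "reject_corrected, threshold = bonferroni_reject(example_p_vals)\n",
    "reject_uncorrected = [int(p < 0.05) for p in example_p_vals]\n",
    "\n",
    "print('threshold:  ', threshold)\n",
    "print('uncorrected:', reject_uncorrected)\n",
    "print('corrected:  ', reject_corrected.tolist())\n",
    "```\n",
    "\n",
    "With six tests the corrected threshold is `0.05 / 6`, so only the two smallest p-values in this example remain significant. The correction trades fewer false positives for lower power, which matters more as the number of neurons grows."
   ]
  },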
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Granger causality in large systems"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "We will now run Granger causality on a large system with 100 neurons. Does it still work well? How does the number of timesteps matter?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:26:35.027772Z",
     "iopub.status.busy": "2021-05-25T01:26:35.027199Z",
     "iopub.status.idle": "2021-05-25T01:26:39.377940Z",
     "shell.execute_reply": "2021-05-25T01:26:39.377492Z"
    },
    "outputId": "dab5aa14-582e-405d-8b46-a7b892754c71"
   },
   "outputs": [],
   "source": [
    "#@markdown Execute this cell to examine Granger causality in a large system\n",
    "n_neurons = 100\n",
    "timesteps = 5000\n",
    "random_state = 42\n",
    "selected_neuron = 1\n",
    "A = create_connectivity(n_neurons, random_state)\n",
    "X = simulate_neurons(A, timesteps, random_state)\n",
    "\n",
    "# get granger causality estimates\n",
    "reject_null, p_vals = get_granger_causality(X, selected_neuron)\n",
    "compare_granger_connectivity(A, reject_null, selected_neuron)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Let's again check the correlation between the Granger estimates and the true connectivity. Are we able to recover the true connectivity well in this larger system?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:26:39.383623Z",
     "iopub.status.busy": "2021-05-25T01:26:39.383027Z",
     "iopub.status.idle": "2021-05-25T01:26:39.385597Z",
     "shell.execute_reply": "2021-05-25T01:26:39.386018Z"
    },
    "outputId": "4b93542c-b09f-44d5-d3f0-8344ea0a9674"
   },
   "outputs": [],
   "source": [
    "np.corrcoef(A[:,selected_neuron], np.array(reject_null))[1, 0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Notes on Granger Causality\n",
    "\n",
    "Here we considered bivariate Granger causality -- for each pair of neurons $A, B$, does one Granger-cause the other? You might wonder whether considering more variables will help with estimation. *Conditional Granger Causality* is a technique that allows for a multivariate system, where we test whether $A$ Granger-causes $B$ conditional on the other variables in the system. \n",
    "\n",
    "Even after controlling for variables in the system, conditional Granger causality will also likely perform poorly as our system gets larger. Plus, measuring the additional variables to condition on may be infeasible in practical applications, which would introduce omitted variable bias as we saw in the regression exercise.\n",
    "\n",
    "One takeaway here is that as our estimation procedures become more sophisticated, they also become more difficult to interpret. We always need to understand the methods and the assumptions that are made."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "W3D5_Tutorial4",
   "provenance": [],
   "toc_visible": true
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
